# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import sys
from typing import List, Optional, Union
 
import numpy as np
from auto_optimizer import OnnxGraph, OnnxNode
 
 
# Per-model optimization plan: maps a ViT configuration name to the list of
# graph rewrites applied for it by apply_optimization().
optimize_plans = {
    "vit_base_patch8_224": ["merge_bmm_axis", "pad_nz_block"],
    "vit_base_patch16_224": ["pad_nz_block"],
    "vit_base_patch16_384": ["merge_bmm_axis", "pad_nz_block"],
    "vit_base_patch32_224": ["merge_bmm_axis"],
    "vit_base_patch32_384": ["merge_bmm_axis", "pad_nz_block"],
}
 
 
def pattern_select(
    graph: OnnxGraph,
    candidate_nodes: List[OnnxNode],
    preorders: Optional[List[Union[str, tuple]]] = None,
    successors: Optional[List[Union[str, tuple]]] = None
) -> List[OnnxNode]:
    """Filter candidate nodes by matching chains of ancestor/descendant op types.

    Each entry of ``preorders``/``successors`` is either an op_type string
    (input/output index 0 is used) or an ``(op_type, index)`` tuple.
    ``preorders`` is given in graph order, so its last entry must be the
    direct producer of the candidate; ``successors`` starts at the direct
    consumer and walks downstream.

    Note: the annotation of ``candidate_nodes`` previously claimed
    ``Union[str, List[str]]``, but the function iterates OnnxNode objects
    (it reads ``.inputs``/``.outputs``/``.op_type``).

    Returns the subset of ``candidate_nodes`` whose neighborhood matches.
    Raises TypeError for pattern entries that are neither str nor tuple.
    """
    matched = []
    preorders = preorders or []
    successors = successors or []

    for node in candidate_nodes:
        pattern_check = True

        # Walk upstream: iterate preorders from the nearest ancestor outward.
        current_node = node
        for p in reversed(preorders):
            if isinstance(p, str):
                op_type, input_idx = p, 0
            elif isinstance(p, tuple):
                op_type, input_idx = p
            else:
                raise TypeError(f"Invalid preorder type: {type(p)}!")

            current_node = graph.get_prev_node(current_node.inputs[input_idx])
            if not current_node or current_node.op_type != op_type:
                pattern_check = False
                break

        if not pattern_check:
            continue

        # Walk downstream: the first consumer with the expected op_type matches.
        current_node = node
        for s in successors:
            if isinstance(s, str):
                op_type, output_idx = s, 0
            elif isinstance(s, tuple):
                op_type, output_idx = s
            else:
                raise TypeError(f"Invalid successor type: {type(s)}!")

            next_nodes = graph.get_next_nodes(current_node.outputs[output_idx])
            pattern_check = False
            for next_node in next_nodes:
                if next_node.op_type == op_type:
                    current_node = next_node
                    pattern_check = True
                    break

            if not pattern_check:
                break

        if pattern_check:
            matched.append(node)

    return matched
 
    
def get_attention_reshape_nodes(graph: OnnxGraph) -> List[OnnxNode]:
    # Reshape nodes sandwiched between a Transpose producer and a MatMul
    # consumer: Transpose -> [Reshape] -> MatMul
    candidates = graph.get_nodes("Reshape")
    return pattern_select(graph, candidates, preorders=["Transpose"], successors=["MatMul"])


def get_first_layernorm_transpose_nodes(graph: OnnxGraph) -> List[OnnxNode]:
    # Transpose nodes preceded by the Mul -> Add chain:
    # Mul -> Add -> [Transpose]
    candidates = graph.get_nodes("Transpose")
    return pattern_select(graph, candidates, preorders=["Mul", "Add"])


def get_last_layernorm_transpose_nodes(graph: OnnxGraph) -> List[OnnxNode]:
    # Transpose nodes directly feeding a Gather: [Transpose] -> Gather
    candidates = graph.get_nodes("Transpose")
    return pattern_select(graph, candidates, successors=["Gather"])
 
 
def get_layernorm_add_nodes(graph: OnnxGraph) -> List[OnnxNode]:
    # Add nodes reached (through input index 1 at each hop) by the chain
    # Mul -> MatMul -> Add -> [Add]
    chain = ["Mul", ("MatMul", 1), ("Add", 1)]
    return pattern_select(graph, graph.get_nodes("Add"), chain)
 
 
def get_layernorm_add_nodes_2(graph: OnnxGraph) -> List[OnnxNode]:
    # Pattern: Reshape -> MatMul -> Add -> [Add]. The Reshape may reach the
    # MatMul through either input, so try input index 1 first, then 0.
    candidates = graph.get_nodes("Add")
    for matmul_input in (1, 0):
        found = pattern_select(graph, candidates, ["Reshape", ("MatMul", matmul_input), ("Add", 1)])
        if found:
            return found
    return []


def get_attention_add_nodes_3(graph: OnnxGraph) -> List[OnnxNode]:
    # Add nodes whose output is consumed by a Slice: [Add] -> Slice
    candidates = graph.get_nodes("Add")
    return pattern_select(graph, candidates, successors=["Slice"])

def get_layernorm_reducemean_nodes(graph: OnnxGraph) -> List[OnnxNode]:
    # [ReduceMean] heading a LayerNorm whose output eventually reaches the
    # qkv Split; the Sub may consume the mean through either input.
    tail = ["Div", "Mul", "Add", "MatMul", "Add", "Reshape", "Transpose", "Split"]
    candidates = graph.get_nodes("ReduceMean")
    for sub_input in (0, 1):
        found = pattern_select(graph, candidates, None, [("Sub", sub_input)] + tail)
        if found:
            return found
    return []


def merge_bmm_axis(graph: OnnxGraph, anchor_reshapes: List[OnnxNode], anchor_adds: List[OnnxNode]) -> None:
    """Fold the batch axis into the token axis around the attention BMMs.

    Rewrites the anchor Reshape targets from (b, x, y) to (b*x, y) and inserts
    auxiliary Reshape nodes so the LayerNorm entries and the final Gather
    still see the original 3-D shape.
    """
    # Unique shape initializers in first-seen order. The original code called
    # ``List(set(...))``, which raises TypeError: typing.List is not callable.
    # dict.fromkeys also keeps the order deterministic (set iteration is not).
    reshape_inits = list(dict.fromkeys(node.inputs[1] for node in anchor_reshapes))
    original_shape = graph[reshape_inits[0]].value
    original_shape_init = graph.add_initializer("original_shape", original_shape)

    # change the target shape of reshape operators: (b, x, y) -> (b*x, y)
    for _init in reshape_inits:
        b, x, y = graph[_init].value
        graph[_init].value = np.array([b * x, y])

    # Merge the axes right after the first residual Add as well.
    first_add_node = graph.get_nodes("Add")[0]
    next_add_node = [node for node in graph.get_next_nodes(first_add_node.outputs[0]) if node.op_type == "Add"][0]

    new_reshape_name = f"Reshape_before_{next_add_node.name}"
    graph.add_node(
        new_reshape_name,
        "Reshape",
        inputs=[first_add_node.outputs[0], reshape_inits[0]],
        outputs=[f"{new_reshape_name}/{next_add_node.name}"],
    )
    next_add_node.inputs[0] = f"{new_reshape_name}/{next_add_node.name}"

    # Restore the original shape temporarily for operator fusion
    for add_node in anchor_adds:
        output_name = add_node.outputs[0]
        new_reshape_name = f"Reshape_after_{add_node.name}"
        graph.add_node(
            new_reshape_name,
            "Reshape",
            inputs=[output_name, original_shape_init.name],
            outputs=[f"{new_reshape_name}_output"],
        )

        # Only the LayerNorm entry ops move onto the restored-shape tensor.
        for next_node in graph.get_next_nodes(output_name):
            if next_node.op_type in ["ReduceMean", "Sub"]:
                next_node.inputs[next_node.inputs.index(output_name)] = f"{new_reshape_name}_output"

    # Restore the original shape at the end, right before the class-token Gather.
    gather_node = graph.get_nodes("Gather")[0]
    new_reshape_name_2 = f"Reshape_before_{gather_node.name}"
    new_reshape_node_2 = graph.add_node(new_reshape_name_2, "Reshape")
    graph.insert_node(gather_node.name, new_reshape_node_2, mode="before")
    new_reshape_node_2.inputs.append(original_shape_init.name)
 
 
def cal_padding_shape(graph: OnnxGraph, merged: bool=False) -> tuple:
    """Compute the zero-padding shape that aligns the token axis to 16.

    Returns ``(padding_shape, original_shape)``: 2-D tuples when ``merged``
    (batch folded into the token axis), 3-D tuples otherwise.
    """
    bs, c, w, h = graph.inputs[0].shape
    first_reshape = graph.get_nodes("Reshape")[0]
    _, hidden_dim1, _ = graph[first_reshape.inputs[1]].value
    # token count: flattened patches divided by the hidden size, plus the class token
    hidden_dim2 = c * w * h // hidden_dim1 + 1

    # distance to the next multiple of 16 (0 when already aligned)
    padding_size = (-hidden_dim2) % 16

    if merged:
        return (bs * padding_size, hidden_dim1), (bs * hidden_dim2, hidden_dim1)

    return (bs, padding_size, hidden_dim1), (bs, hidden_dim2, hidden_dim1)
 
 
def pad_nz_block(
    graph: OnnxGraph,
    anchor_reshapes: List[OnnxNode],
    anchor_adds: List[OnnxNode],
    anchor_adds_2: List[OnnxNode],
    merged: bool=False
) -> None:
    """Zero-pad the token axis up to a multiple of 16 (NZ block size).

    A constant zero block is concatenated after the embedding Add and after
    each anchor Reshape; a Slice before every LayerNorm entry strips the
    padding again, so the numerical result is unchanged.

    ``merged`` selects the concat/slice axis: 0 when merge_bmm_axis already
    folded the batch axis into the tokens, 1 otherwise.
    """
    padding_shape, original_shape = cal_padding_shape(graph, merged)
    axis = 0 if merged else 1

    # Shared zero block appended by every Concat below.
    # (was f"padding_concat_init": no placeholder, a plain string suffices)
    new_concat_init = graph.add_initializer("padding_concat_init", np.zeros(padding_shape, dtype=np.float32))

    # Pad right before the first residual Add.
    add_node = anchor_adds_2[0]
    new_concat_name = f"Concat_before_{add_node.name}"
    new_concat_node = graph.add_node(new_concat_name, "Concat", attrs={"axis": axis})
    graph.insert_node(add_node.name, new_concat_node, refer_index=0, mode="before")
    new_concat_node.inputs.append(new_concat_init.name)

    # Pad after each attention Reshape.
    for reshape in anchor_reshapes:
        new_concat_name = f"Concat_after_{reshape.name}"
        new_concat_node = graph.add_node(new_concat_name, "Concat", attrs={"axis": axis})
        graph.insert_node(reshape.name, new_concat_node)
        new_concat_node.inputs.append(new_concat_init.name)

    # Strip the padding before every LayerNorm / residual consumer.
    for add_node in anchor_adds:
        output_name = add_node.outputs[0]
        new_slice_name = f"Slice_before_{add_node.name}"
        new_slice_init_starts = graph.add_initializer(f"{new_slice_name}_init_starts", np.array([0]))
        new_slice_init_ends = graph.add_initializer(f"{new_slice_name}_init_ends", np.array([original_shape[axis]]))
        new_slice_init_axes = graph.add_initializer(f"{new_slice_name}_init_axes", np.array([axis]))
        graph.add_node(
            new_slice_name,
            "Slice",
            inputs=[output_name, new_slice_init_starts.name, new_slice_init_ends.name, new_slice_init_axes.name],
            outputs=[f"{new_slice_name}_output"],
            )

        # Only these consumer types move onto the sliced (unpadded) tensor.
        for next_node in graph.get_next_nodes(output_name):
            if next_node.op_type in ["ReduceMean", "Sub", "Reshape"]:
                next_node.inputs[next_node.inputs.index(output_name)] = f"{new_slice_name}_output"


def convert_gemm_to_matmul_add(graph: OnnxGraph, anchor_gemms: List[OnnxNode]) -> None:
    """Replace every Reshape-wrapped Gemm with an equivalent MatMul + Add pair.

    pattern::

        Reshape(2-dims) -> Gemm -> Reshape(3-dims) -> Add
            becomes
        Reshape -> MatMul -> Add -> Add

    Only Gemm nodes whose producer and first consumer are both Reshape are
    rewritten; the trailing Reshape is dropped and the new Add writes straight
    to its output tensor.
    """
    for gemm in anchor_gemms:
        producer = graph.get_prev_node(gemm.inputs[0])
        consumer = graph.get_next_nodes(gemm.outputs[0])[0]
        if producer.op_type != "Reshape" or consumer.op_type != "Reshape":
            continue

        # MatMul takes the pre-Gemm activations and the Gemm weight.
        matmul_name = f"matmul_replace_{gemm.name}"
        graph.add_node(
            matmul_name,
            "MatMul",
            inputs=[producer.outputs[0], gemm.inputs[1]],
            outputs=[f"{matmul_name}_out"]
        )

        # Add applies the Gemm bias and reuses the old Reshape output name.
        add_name = f"add_replace_{gemm.name}"
        graph.add_node(
            add_name,
            "Add",
            inputs=[f"{matmul_name}_out", gemm.inputs[2]],
            outputs=[consumer.outputs[0]]
        )

        graph.remove(gemm.name, {})
        graph.remove(consumer.name, {})


def delete_transpose(graph: OnnxGraph, anchor_transposes: List[OnnxNode]) -> None:
    """Remove every anchor Transpose node from the graph."""
    for node in anchor_transposes:
        graph.remove(node.name)


def adapt_for_layernormqkv(
    graph: OnnxGraph,
    anchor_adds: List[OnnxNode],
    anchor_softmaxes: List[OnnxNode]
) -> None:
    """
    Replace the Slice-based q/k/v extraction after each fused qkv projection
    with a Reshape -> Transpose -> Split -> Squeeze chain.

    pattern (per attention block)::

        Add -> {Slice_x -> Reshape_x -> Transpose_x} for x in (q, k, v)
            becomes
        Add -> Reshape -> Transpose -> Split -> Squeeze_x
        (the k branch additionally gets a Transpose before the score MatMul)

    The downstream Div/MatMul/Softmax/MatMul core is kept and rewired onto the
    new chain; unused Slice/Reshape/Transpose nodes are cleaned up at the end.
    """
    # Token/hidden sizes of the (unmerged) activations.
    _, (bs, hidden_dim2, hidden_dim1) = cal_padding_shape(graph, merged=False)

    for add_node, softmax_node in zip(anchor_adds, anchor_softmaxes):

        # Locate the attention core around this Softmax.
        matmul_before_softmax = graph.get_prev_node(softmax_node.inputs[0])
        matmul_after_softmax = graph.get_next_nodes(softmax_node.outputs[0])[0]
        div_before_matmul = graph.get_prev_node(matmul_before_softmax.inputs[0])
        transpose_after_matmul = graph.get_next_nodes(matmul_after_softmax.outputs[0])[0]
        reshape_after_transpose = graph.get_next_nodes(transpose_after_matmul.outputs[0])[0]

        # Target shape (bs, tokens, 3, heads, head_dim); 3 = q/k/v.
        # NOTE(review): 12 heads x 64 head_dim is ViT-base specific — confirm
        # before reusing with other configurations.
        new_reshape_name = f"Reshape_after_{add_node.name}"
        new_reshape_init = graph.add_initializer(
            f"{new_reshape_name}_init",
            np.array([bs, hidden_dim2, 3, 12, 64])
        )
        new_reshape_node = graph.add_node(
            new_reshape_name,
            "Reshape",
            inputs=[add_node.outputs[0], new_reshape_init.name],
            outputs=[f"{new_reshape_name}_output"]
        )

        # Bring the q/k/v axis to the front: (3, bs, heads, tokens, head_dim).
        new_transpose_name = f"Transpose_after_{new_reshape_name}"
        new_transpose_node = graph.add_node(
            new_transpose_name,
            "Transpose",
            inputs=[new_reshape_node.outputs[0]],
            outputs=[f"{new_transpose_name}_output"],
            attrs={"perm": [2,0,3,1,4]}
        )

        new_split_name = f"Split_after_{new_transpose_name}"
        graph.add_node(
            new_split_name,
            "Split",
            inputs=[new_transpose_node.outputs[0]],
            outputs=[f"{new_split_name}_output_q", f"{new_split_name}_output_k", f"{new_split_name}_output_v"],
            attrs={"axis": 0}
        )

        # k is consumed transposed by the score MatMul.
        new_transpose_2_name = f"Transpose_after_{new_split_name}"
        graph.add_node(
            new_transpose_2_name,
            "Transpose",
            inputs=[f"{new_split_name}_output_k"],
            outputs=[f"{new_transpose_2_name}_output"],
            attrs={"perm": [0,1,3,2]}
        )

        # Rewire the q/k/v consumers onto the new chain.
        div_before_matmul.inputs[0] = f"{new_split_name}_output_q"
        matmul_before_softmax.inputs[1] = f"{new_transpose_2_name}_output"
        matmul_after_softmax.inputs[1] = f"{new_split_name}_output_v"

        # Drop the leading length-1 q/k/v axis of each Split output.
        for idx in range(3):
            new_squeeze_name = f"Squeeze_after_{new_split_name}_{idx}"
            new_squeeze_node = graph.add_node(
                new_squeeze_name,
                "Squeeze",
                attrs = {"axes": [0]}
            )
            graph.insert_node(new_split_name, new_squeeze_node, refer_index=idx, mode="after")

        # Update the output permutation via auto_optimizer's attribute indexing.
        transpose_after_matmul["perm"] = [0,2,1,3]

        # Fixed: interpolate the node's *name*; the original formatted the
        # OnnxNode object itself into the initializer name.
        reshape_init = graph.add_initializer(
            f"{reshape_after_transpose.name}_init",
            np.array([bs, hidden_dim2, hidden_dim1])
        )
        reshape_after_transpose.inputs[1] = reshape_init.name

    graph.update_map()
    graph.remove_unused_nodes()


def adapt_for_flashattention(graph: OnnxGraph, anchor_softmaxes: List[OnnxNode]) -> None:
    """Fuse each attention core into a single FlashAttentionTik node.

    pattern (per Softmax)::

        q --------\
        k -> Transpose -> MatMul -> SoftMax -> MatMul <- v
            becomes
        (q, k, v) -> FlashAttentionTik

    The fused node takes the raw q (post-Div), the un-transposed k and v, and
    writes to the second MatMul's output; the four replaced nodes are removed.
    """
    fa_name = "FlashAttentionTik"
    for softmax_node in anchor_softmaxes:
        score_matmul = graph.get_prev_node(softmax_node.inputs[0])
        context_matmul = graph.get_next_nodes(softmax_node.outputs[0])[0]
        key_transpose = graph.get_prev_node(score_matmul.inputs[1])

        fused = graph.add_node(softmax_node.name.replace("Softmax", fa_name), fa_name)
        fused.inputs = [score_matmul.inputs[0], key_transpose.inputs[0], context_matmul.inputs[1]]
        fused.outputs = context_matmul.outputs

        for stale in (score_matmul, softmax_node, key_transpose, context_matmul):
            graph.remove(stale.name, {})


def adapt_for_attentionscore(graph: OnnxGraph, anchor_softmaxes: List[OnnxNode]) -> None:
    """Insert a ``* (1/d)`` Mul in front of each attention Softmax.

    ``d`` is read from the Div that scales q before the score MatMul (see the
    pattern: Div becomes Muls). The Div node itself is not removed here —
    presumably a later fusion pass takes care of it; confirm against the
    downstream tooling.
    """
    for softmax_node in anchor_softmaxes:
        score_matmul = graph.get_prev_node(softmax_node.inputs[0])
        scale_div = graph.get_prev_node(score_matmul.inputs[0])

        mul_name = f"Mul_before_{softmax_node.name}"
        mul_node = graph.add_node(mul_name, "Mul")
        divisor = graph[scale_div.inputs[1]].value
        scale_init = graph.add_initializer(
            f"{mul_name}_init",
            np.array(1/divisor, dtype="float32")
        )
        graph.insert_node(softmax_node.name, mul_node, mode="before")
        mul_node.inputs.append(scale_init.name)


def split_layernormqkv_and_flashattention(
    graph: OnnxGraph, 
    anchor_reducemeans: List[OnnxNode],
    anchor_splits: List[OnnxNode],
    anchor_flashattentions: List[OnnxNode]
) -> None:
    """
    If batch size is 20, split layernormqkv and flashattention with 4bs and 16bs.

    For each (ReduceMean, Split, FlashAttentionTik) triple, the subgraph from
    the ReduceMean to the FlashAttentionTik is duplicated; the original branch
    then processes batch rows [0, 4), the duplicate processes rows [4, 20),
    and the two branch outputs are concatenated back together on axis 0.
    """
    for reducemean, split, flashattention in zip(anchor_reducemeans, anchor_splits, anchor_flashattentions):
        # extract subgraph for copy
        # NOTE(review): this writes "temp.onnx" into the current working
        # directory and never removes it — confirm that is acceptable.
        graph.extract_subgraph(
            [reducemean.name],
            [flashattention.name],
            "./temp.onnx"
        )
        temp_graph = OnnxGraph.parse("temp.onnx")
        # Suffix every name in the copy so nothing clashes when it is merged
        # back into the main graph below.
        for node in temp_graph.nodes:
            node.name += "_copy"
            node.inputs = [inp + "_copy" for inp in node.inputs]
            node.outputs = [out + "_copy" for out in node.outputs]
        for initializer in temp_graph.initializers:
            initializer.name += "_copy"
        temp_graph.update_map()

        # get nodes
        sub = graph.get_next_nodes(reducemean.outputs[0])[0]
        transpose = graph.get_prev_node(split.inputs[0])
        reshape = graph.get_prev_node(transpose.inputs[0])

        temp_reducemean = temp_graph.get_nodes("ReduceMean")[0]
        temp_sub = temp_graph.get_next_nodes(temp_reducemean.outputs[0])[0]
        temp_split = temp_graph.get_nodes("Split")[0]
        temp_transpose = temp_graph.get_prev_node(temp_split.inputs[0])
        temp_reshape = temp_graph.get_prev_node(temp_transpose.inputs[0])
        temp_flashattention = temp_graph.get_nodes("FlashAttentionTik")[0]

        # create new Slice node
        # Rows [0, 4) of the shared input feed the original branch.
        new_slice_4bs_name = f"Slice_4bs_before_{reducemean.name}"
        new_slice_4bs_init_starts = graph.add_initializer(f"{new_slice_4bs_name}_init_starts", np.array([0]))
        new_slice_4bs_init_ends = graph.add_initializer(f"{new_slice_4bs_name}_init_ends", np.array([4]))
        new_slice_4bs_init_axes = graph.add_initializer(f"{new_slice_4bs_name}_init_axes", np.array([0]))
        graph.add_node(
            new_slice_4bs_name, 
            "Slice",
            inputs=[reducemean.inputs[0], new_slice_4bs_init_starts.name, \
                    new_slice_4bs_init_ends.name, new_slice_4bs_init_axes.name],
            outputs=[f"{new_slice_4bs_name}_output"],
            )

        # Rows [4, 20) feed the duplicated branch.
        new_slice_16bs_name = f"Slice_16bs_before_{reducemean.name}"
        new_slice_16bs_init_starts = graph.add_initializer(f"{new_slice_16bs_name}_init_starts", np.array([4]))
        new_slice_16bs_init_ends = graph.add_initializer(f"{new_slice_16bs_name}_init_ends", np.array([20]))
        new_slice_16bs_init_axes = graph.add_initializer(f"{new_slice_16bs_name}_init_axes", np.array([0]))
        graph.add_node(
            new_slice_16bs_name, 
            "Slice",
            inputs=[reducemean.inputs[0], new_slice_16bs_init_starts.name, \
                    new_slice_16bs_init_ends.name, new_slice_16bs_init_axes.name],
            outputs=[f"{new_slice_16bs_name}_output"],
            )

        # Rewire each branch's LayerNorm entry (ReduceMean and Sub) onto its slice.
        reducemean.inputs = [f"{new_slice_4bs_name}_output"]
        sub.inputs = [f"{new_slice_4bs_name}_output", reducemean.outputs[0]]
        temp_reducemean.inputs = [f"{new_slice_16bs_name}_output"]
        temp_sub.inputs[0] = f"{new_slice_16bs_name}_output"

        # create new Concat node
        # Joins the 4-batch and 16-batch attention outputs back on axis 0.
        new_concat_name = f"Concat_after_{flashattention.name}"
        new_concat_node = graph.add_node(
            new_concat_name,
            "Concat",
            attrs={"axis":0}
        )
        graph.insert_node(flashattention.name, new_concat_node, mode="after")
        new_concat_node.inputs.append(f"{new_concat_name}_input")
        temp_flashattention.outputs = [f"{new_concat_name}_input"]

        # fix reshape for bs4
        reshape_init = graph[reshape.inputs[1]]
        shape = reshape_init.value
        reshape_init.value = np.hstack((np.array([4]), shape[1:]))

        temp_reshape_init = temp_graph[temp_reshape.inputs[1]]
        temp_shape = temp_reshape_init.value
        # NOTE(review): reuses ``shape[1:]`` from the original branch instead
        # of ``temp_shape[1:]`` (which is otherwise unused); the two should be
        # equal because the subgraph is a copy — confirm.
        temp_reshape_init.value = np.hstack((np.array([16]), shape[1:]))

        # merge subgraph
        graph.nodes.extend(temp_graph.nodes)
        graph.initializers.extend(temp_graph.initializers)
        # NOTE(review): set() de-duplication does not preserve initializer
        # order; assumes order is irrelevant to the graph serialization.
        graph.initializers = list(set(graph.initializers))
        graph.update_map()

 
def apply_optimization(opts) -> None:
    """Run the configured optimization pipeline on the ONNX model and save it.

    opts: parsed CLI namespace providing ``input_file``, ``output_file``,
    ``model_config`` and ``use_flashattention``.

    Raises ValueError when the model's batch size is unsupported.
    """
    # Default to an empty plan so an unknown config still gets the always-on
    # attention rewrites below (the original crashed iterating None).
    plan = optimize_plans.get(opts.model_config, [])
    merged_axis = False
    # Parse once; the original parsed the same file twice.
    g = OnnxGraph.parse(opts.input_file)
    bs = g.inputs[0].shape[0]

    supported_bs = [1, 4, 8, 16, 24, 32, 40, 48, 56, 64, 20]
    if bs not in supported_bs:
        raise ValueError(f"The batch size is {bs}. Only support batch size {supported_bs}")

    gemms = g.get_nodes("Gemm")
    softmaxes = g.get_nodes("Softmax")
    if gemms:
        convert_gemm_to_matmul_add(g, gemms)

    # Collect all anchor nodes before any rewiring happens.
    reshapes = get_attention_reshape_nodes(g)
    transposes = get_first_layernorm_transpose_nodes(g) + \
        get_last_layernorm_transpose_nodes(g)
    adds = get_layernorm_add_nodes(g)
    adds_2 = get_layernorm_add_nodes_2(g)
    adds_3 = get_attention_add_nodes_3(g)

    delete_transpose(g, transposes)
    adapt_for_layernormqkv(g, adds_3, softmaxes)
    if opts.use_flashattention:
        adapt_for_flashattention(g, softmaxes)
    else:
        adapt_for_attentionscore(g, softmaxes)

    for opt in plan:
        if opt == "merge_bmm_axis":
            merge_bmm_axis(g, reshapes, adds)
            merged_axis = True
        elif opt == "pad_nz_block":
            pad_nz_block(g, reshapes, adds, adds_2, merged_axis)

    # Batch size 20 is split into 4 + 16 sub-batches (flashattention only).
    # Fixed: the original read the module-global ``args`` here, which raises
    # NameError when this function is imported instead of run as a script.
    if bs == 20 and opts.use_flashattention:
        reducemeans = get_layernorm_reducemean_nodes(g)
        splits = g.get_nodes("Split")
        flashattentions = g.get_nodes("FlashAttentionTik")
        split_layernormqkv_and_flashattention(g, reducemeans, splits, flashattentions)

    g.infer_shape()
    g.save(opts.output_file)
 
 
if __name__ == "__main__":
    # Command-line entry point: parse the CLI arguments and run the pipeline.
    parser = argparse.ArgumentParser(description="img onnx modification")
    parser.add_argument("--input_file", type=str, required=True,
                        help="path to input onnx")
    parser.add_argument("--output_file", type=str, required=True,
                        help="path to output onnx")
    # model_config selects the optimization plan (see optimize_plans).
    parser.add_argument("--model_config", type=str, required=True,
                        help="model_config of ViT", 
                        choices=["vit_base_patch8_224", "vit_base_patch16_224", \
                                 "vit_base_patch16_384", "vit_base_patch32_224", "vit_base_patch32_384"])
    parser.add_argument("--use_flashattention", action="store_true",
                        help="If passed, use flashattention.")
    args = parser.parse_args()

    apply_optimization(args)